Tensor展开和torch.cat拼接原理 您所在的位置:网站首页 cat 原理 Tensor展开和torch.cat拼接原理

Tensor展开和torch.cat拼接原理

#Tensor展开和torch.cat拼接原理| 来源: 网络整理| 查看: 265

@ 大纲 tensor简介 tensor展开 torch.cat

最近用到torch.cat,需要搞明白tensor的展开原理和cat拼接技术,于是就有了这篇文章,仅为初学者阅读,大牛请略过,谢谢!

tensor

tensor 即张量,我理解为数据关系的一种表示,0阶张量是一个数字,也叫标量。1阶张量是一组数字,也叫向量。2阶张量多个向量组成,也叫矩阵。3阶张量,多个矩阵组成,构成一个立方体。4阶张量多个立方体组成。阶数再往上上可以自己发挥。话不多说,上图。 0-4阶张量标示图 深度学习里面,一般都是4阶张量,比如[16,128,28,28],batch size = 16,卷积后的图片大小是2828128。

tensor展开

如图中三阶张量展开图,模态1红色框,长方体的宽度。模态2浅蓝色框,长方体的长度,模态3是绿色框,长方体的深度。 三阶张量展开图 #torch.cat cat是行(axis = 0)或者列(axis=1)上进行拼接,其他维度不变。对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。

import torch A = torch.tensor([[[1,13],[2,14],[3,15],[4,16]],[[5,17],[6,18],[7,19],[8,20]],[[9,21],[10,22],[11,23],[12,24]]]) print(A) print(A.size()) B = torch.tensor([[[1,13],[2,14],[3,15],[4,16]],[[5,17],[6,18],[7,19],[8,20]],[[9,21],[10,22],[11,23],[12,24]]]) print(B.size()) C = torch.cat((A,B),1) print(C) print(C.size()) tensor([[[ 1, 13], [ 2, 14], [ 3, 15], [ 4, 16]], [[ 5, 17], [ 6, 18], [ 7, 19], [ 8, 20]], [[ 9, 21], [10, 22], [11, 23], [12, 24]]]) torch.Size([3, 4, 2]) torch.Size([3, 4, 2]) tensor([[[ 1, 13], [ 2, 14], [ 3, 15], [ 4, 16], [ 1, 13], [ 2, 14], [ 3, 15], [ 4, 16]], [[ 5, 17], [ 6, 18], [ 7, 19], [ 8, 20], [ 5, 17], [ 6, 18], [ 7, 19], [ 8, 20]], [[ 9, 21], [10, 22], [11, 23], [12, 24], [ 9, 21], [10, 22], [11, 23], [12, 24]]]) torch.Size([3, 8, 2])

修改tensor B 为 [3,2,2],只能修改第二个维度,改别的会报错。则如下:

A = torch.tensor([[[1,13],[2,14],[3,15],[4,16]],[[5,17],[6,18],[7,19],[8,20]],[[9,21],[10,22],[11,23],[12,24]]]) print(A) print(A.size()) B = torch.tensor([[[1,13],[2,14]],[[5,17],[6,18]],[[9,21],[10,22]]]) print(B.size()) C = torch.cat((A,B),1) print(C) print(C.size())

输出

tensor([[[ 1, 13], [ 2, 14], [ 3, 15], [ 4, 16]], [[ 5, 17], [ 6, 18], [ 7, 19], [ 8, 20]], [[ 9, 21], [10, 22], [11, 23], [12, 24]]])

输出:

torch.Size([3, 4, 2]) torch.Size([3, 2, 2]) tensor([[[ 1, 13], [ 2, 14], [ 3, 15], [ 4, 16], [ 1, 13], [ 2, 14]], [[ 5, 17], [ 6, 18], [ 7, 19], [ 8, 20], [ 5, 17], [ 6, 18]], [[ 9, 21], [10, 22], [11, 23], [12, 24], [ 9, 21], [10, 22]]]) torch.Size([3, 6, 2]) Tensor拼接使用实例

GoogLeNet中得Inception模块,就使用了拼接。

class Inception(nn.Module): def __init__(self,input_channels,n1x1,n3x3_reduce,n3x3,n5x5_reduce,n5x5,pool_proj): super().__init__() #1x1conv branch self.b1 = nn.Sequential( nn.Conv2d(input_channels,n1x1,kernel_size=1), nn.BatchNorm2d(n1x1), nn.ReLU(inplace=True) ) #1x1conv -> 3x3conv branch self.b2 = nn.Sequential( nn.Conv2d(input_channels,n3x3_reduce,kernel_size=1), nn.BatchNorm2d(n3x3_reduce), nn.ReLU(inplace=True), nn.Conv2d(n3x3_reduce,n3x3,kernel_size=3,padding=1), nn.BatchNorm2d(n3x3), nn.ReLU(inplace=True) ) #1x1 -> 5x5conv branch # we use 2 3x3 conv filters instead of 1 5x5 conv filter # to obtain the same receptive and reduce numbers of parameters self.b3 = nn.Sequential( nn.Conv2d(input_channels,n5x5_reduce,kernel_size=1), nn.BatchNorm2d(n5x5_reduce), nn.ReLU(inplace=True), nn.Conv2d(n5x5_reduce,n5x5,kernel_size=3,padding=1), nn.BatchNorm2d(n5x5), nn.ReLU(inplace=True), nn.Conv2d(n5x5,n5x5,kernel_size=3,padding=1), nn.BatchNorm2d(n5x5), nn.ReLU(inplace=True) ) #3x3 pooling -> 1x1 conv self.b4 = nn.Sequential( nn.MaxPool2d(3,stride=1,padding=1,ceil_mode=True), nn.Conv2d(input_channels,pool_proj,kernel_size=1), nn.BatchNorm2d(pool_proj), nn.ReLU(inplace=True) ) def forward(self,x): b1 = self.b1(x) print('b1 size',b1.size()) b2 = self.b2(x) print('b2 size',b2.size()) b3 = self.b3(x) print('b3 size',b3.size()) b4 = self.b4(x) print('b4 size',b4.size()) output = [self.b1(x),self.b2(x),self.b3(x),self.b4(x)] print('after cat size',torch.cat(output,1).size()) return torch.cat(output,1)

输出结果是: 拼接以后的 torch.Size([16, 256, 28, 28])

b1 size torch.Size([16, 64, 28, 28]) b2 size torch.Size([16, 128, 28, 28]) b3 size torch.Size([16, 32, 28, 28]) b4 size torch.Size([16, 32, 28, 28]) after cat size torch.Size([16, 256, 28, 28]) b1 size torch.Size([16, 128, 28, 28]) b2 size torch.Size([16, 192, 28, 28]) b3 size torch.Size([16, 96, 28, 28]) b4 size torch.Size([16, 64, 28, 28]) after cat size torch.Size([16, 480, 28, 28]) b1 size torch.Size([16, 192, 14, 14]) b2 size torch.Size([16, 208, 14, 14]) b3 size torch.Size([16, 48, 14, 14]) b4 size torch.Size([16, 64, 14, 14]) after cat size torch.Size([16, 512, 14, 14]) b1 size torch.Size([16, 160, 14, 14]) b2 size torch.Size([16, 224, 14, 14]) b3 size torch.Size([16, 64, 14, 14]) b4 size torch.Size([16, 64, 14, 14]) after cat size torch.Size([16, 512, 14, 14]) b1 size torch.Size([16, 128, 14, 14]) b2 size torch.Size([16, 256, 14, 14]) b3 size torch.Size([16, 64, 14, 14]) b4 size torch.Size([16, 64, 14, 14]) after cat size torch.Size([16, 512, 14, 14]) b1 size torch.Size([16, 112, 14, 14]) b2 size torch.Size([16, 288, 14, 14]) b3 size torch.Size([16, 64, 14, 14]) b4 size torch.Size([16, 64, 14, 14]) after cat size torch.Size([16, 528, 14, 14]) b1 size torch.Size([16, 256, 14, 14]) b2 size torch.Size([16, 320, 14, 14]) b3 size torch.Size([16, 128, 14, 14]) b4 size torch.Size([16, 128, 14, 14]) after cat size torch.Size([16, 832, 14, 14]) b1 size torch.Size([16, 256, 7, 7]) b2 size torch.Size([16, 320, 7, 7]) b3 size torch.Size([16, 128, 7, 7]) b4 size torch.Size([16, 128, 7, 7]) after cat size torch.Size([16, 832, 7, 7]) b1 size torch.Size([16, 384, 7, 7]) b2 size torch.Size([16, 384, 7, 7]) b3 size torch.Size([16, 128, 7, 7]) b4 size torch.Size([16, 128, 7, 7]) after cat size torch.Size([16, 1024, 7, 7]) 参考文献

感谢各位大牛的指导 [1]: https://www5.in.tum.de/persons/huckle/tensor-kurs_1.pdf [2]: https://staffwww.dcs.shef.ac.uk/people/H.Lu/MSL/MSLbook-Chapter3.pdf [3]: http://www.cs.cornell.edu/cv/SummerSchool/Unfold.pdf



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有